
InitContext, part 4 - Use init!! to replace evaluate_and_sample!!, predict, returned, and initialize_values #984


Draft: wants to merge 9 commits into base branch py/init-prior-uniform

Conversation


@penelopeysm (Member) commented on Jul 10, 2025

Part 1: Adding hasvalue and getvalue to AbstractPPL
Part 2: Removing hasvalue and getvalue from DynamicPPL
Part 3: Introducing InitContext and init!!

This is part 4/N of #967.


In Part 3 we introduced InitContext. This PR uses that functionality to replace a bunch of code that no longer needs to exist:

  • setval_and_resample! followed by model evaluation: This process was used for predict and returned, to manually store certain values in the VarInfo, which would be used in the subsequent model evaluation. We can now do this in a single step using ParamsInit.
  • initialize_values!!: Very similar to the above. It manually set values inside the varinfo and then triggered an extra model evaluation to update the logp field. Again, this is directly replaced with ParamsInit.
  • evaluate_and_sample!!: A direct one-to-one replacement with init!!.
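To illustrate the unified pattern, here is a minimal sketch assuming the init!! / ParamsInit / PriorInit API introduced in Part 3. The model, variable names, and exact signatures are illustrative placeholders and may differ from the final interface:

```julia
# Hypothetical sketch of the Part-3 initialisation API; not runnable
# against any released DynamicPPL version.
using Random
using Distributions
using DynamicPPL

@model function demo()
    x ~ Normal()
    y ~ Normal(x)
end

model = demo()
vi = VarInfo(model)

# Fix `x` to a known value; any variable missing from the Dict (here `y`)
# falls back to the second strategy and is sampled from the prior. This is
# a single call and a single model evaluation, with logp updated in the
# same pass:
_, vi = DynamicPPL.init!!(
    Random.default_rng(),
    model,
    vi,
    DynamicPPL.ParamsInit(Dict(@varname(x) => 1.0), DynamicPPL.PriorInit()),
)
```

Previously this required setval_and_resample! followed by a separate model evaluation.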

There is one API change associated with this: the initial_params kwarg to sample must now be an AbstractInitStrategy. It's still optional (it will usually default to PriorInit). However, there are two implications:

  • initial_params can no longer be a vector of parameters. It must be ParamsInit(::NamedTuple) or ParamsInit(::AbstractDict{VarName}).
  • Because ParamsInit expects values in unlinked space, initial_params must always be specified in unlinked space. Previously, initial_params would have to be specified in a way that matched the linking status of the underlying varinfo.

I consider both of these to be a major win for clarity. (One might argue that vectors are more convenient. Sure, you can get a vector with vi[:], but now you can just do values_as(vi, Dict{VarName,Any}) instead.)
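A hypothetical before/after sketch of the initial_params change, assuming the ParamsInit-based API from this PR series (model, sampler, and values are placeholders, not runnable against a released DynamicPPL):

```julia
# Hypothetical before/after for the `initial_params` kwarg to `sample`.
using DynamicPPL

# Before: a flat vector, which had to match the linking status of the
# sampler's internal varinfo:
chn = sample(model, sampler, 100; initial_params=[0.5, -1.2])

# After: an AbstractInitStrategy whose values are always given in
# unlinked space:
chn = sample(model, sampler, 100; initial_params=ParamsInit((; x=0.5, y=-1.2)))

# If you previously extracted a vector from a varinfo with `vi[:]`, the
# Dict-based equivalent is:
init_values = values_as(vi, Dict{VarName,Any})
chn = sample(model, sampler, 100; initial_params=ParamsInit(init_values))
```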

Closes

Closes #774
Closes #797
Closes #983
Closes TuringLang/Turing.jl#2476
Closes TuringLang/Turing.jl#1775


github-actions bot commented Jul 10, 2025

Benchmark Report for Commit f6dd1d5

Computer Information

Julia Version 1.11.6
Commit 9615af0f269 (2025-07-09 12:58 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 4 × AMD EPYC 7763 64-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver3)
Threads: 1 default, 0 interactive, 1 GC (on 4 virtual cores)

Benchmark Results

|                 Model | Dimension |  AD Backend |      VarInfo Type | Linked | Eval Time / Ref Time | AD Time / Eval Time |
|-----------------------|-----------|-------------|-------------------|--------|----------------------|---------------------|
| Simple assume observe |         1 | forwarddiff |             typed |  false |                  8.9 |                 1.6 |
|           Smorgasbord |       201 | forwarddiff |             typed |  false |                669.6 |                43.4 |
|           Smorgasbord |       201 | forwarddiff | simple_namedtuple |   true |                437.9 |                60.9 |
|           Smorgasbord |       201 | forwarddiff |           untyped |   true |               1216.2 |                31.3 |
|           Smorgasbord |       201 | forwarddiff |       simple_dict |   true |               6754.3 |                29.1 |
|           Smorgasbord |       201 | reversediff |             typed |   true |               1075.2 |                40.0 |
|           Smorgasbord |       201 |    mooncake |             typed |   true |               1036.3 |                 4.2 |
|    Loop univariate 1k |      1000 |    mooncake |             typed |   true |               5977.1 |                 3.9 |
|       Multivariate 1k |      1000 |    mooncake |             typed |   true |               1029.5 |                 8.8 |
|   Loop univariate 10k |     10000 |    mooncake |             typed |   true |              67599.5 |                 3.6 |
|      Multivariate 10k |     10000 |    mooncake |             typed |   true |               9197.2 |                 9.6 |
|               Dynamic |        10 |    mooncake |             typed |   true |                148.2 |                12.8 |
|              Submodel |         1 |    mooncake |             typed |   true |                 13.5 |                 4.9 |
|                   LDA |        12 | reversediff |             typed |   true |               1166.3 |                 3.5 |

Comment on lines 126 to 134
# Extract values from the chain
values_dict = chain_sample_to_varname_dict(parameter_only_chain, sample_idx, chain_idx)
# Resample any variables that are not present in `values_dict`
_, varinfo = last(
    DynamicPPL.init!!(
        rng,
        model,
        varinfo,
        DynamicPPL.ParamsInit(values_dict, DynamicPPL.PriorInit()),
    ),
)
@penelopeysm (Member Author) commented on Jul 10, 2025

Note that, if the chain does not store varnames inside its info field, chain_sample_to_varname_dict will fail.

I don't consider this to be a problem right now because every chain obtained via Turing's sample() will contain varnames:

https://github.com/TuringLang/Turing.jl/blob/1aa95ac91a115569c742bab74f7b751ed1450309/src/mcmc/Inference.jl#L288-L290

So this is only a problem if you manually construct a chain and try to call predict on it, which I think is a highly unlikely workflow (and I'm happy to wait for people to complain if it fails). There are a few places in DynamicPPL's test suite where this does actually happen. I fixed them all by manually adding the varname dictionary.

However, it's obviously ugly. The only good way around this is, as I suggested before, to rework MCMCChains.jl. (See here for the implementation of the corresponding functionality in FlexiChains.)

@penelopeysm penelopeysm changed the title Use init!! to replace evaluate_and_sample!!, predict, returned, and initialize_values InitContext, part 4 - Use init!! to replace evaluate_and_sample!!, predict, returned, and initialize_values Jul 10, 2025
@penelopeysm penelopeysm force-pushed the py/init-prior-uniform branch 2 times, most recently from 025aa8b to b55c1e1 Compare July 10, 2025 14:24
@penelopeysm penelopeysm force-pushed the py/actually-use-init branch 5 times, most recently from b72c3bf to 92d3542 Compare July 10, 2025 15:57
@penelopeysm penelopeysm mentioned this pull request Jul 10, 2025
22 tasks
@penelopeysm penelopeysm force-pushed the py/actually-use-init branch 4 times, most recently from 7438b23 to d55d378 Compare July 10, 2025 16:56
@penelopeysm penelopeysm force-pushed the py/actually-use-init branch 3 times, most recently from 12d93e5 to 7a8e7e3 Compare July 10, 2025 17:47
@penelopeysm penelopeysm force-pushed the py/actually-use-init branch 2 times, most recently from 1d8bceb to 2edcd10 Compare July 20, 2025 00:59

DynamicPPL.jl documentation for PR #984 is available at:
https://TuringLang.github.io/DynamicPPL.jl/previews/PR984/


codecov bot commented Jul 20, 2025

Codecov Report

❌ Patch coverage is 86.74699% with 11 lines in your changes missing coverage. Please review.
✅ Project coverage is 80.57%. Comparing base (fd78d42) to head (23cafe0).

Files with missing lines Patch % Lines
src/simple_varinfo.jl 40.00% 6 Missing ⚠️
src/test_utils/contexts.jl 85.71% 4 Missing ⚠️
src/test_utils/model_interface.jl 0.00% 1 Missing ⚠️
Additional details and impacted files
@@                    Coverage Diff                    @@
##           py/init-prior-uniform     #984      +/-   ##
=========================================================
- Coverage                  82.41%   80.57%   -1.85%     
=========================================================
  Files                         39       39              
  Lines                       3992     3927      -65     
=========================================================
- Hits                        3290     3164     -126     
- Misses                       702      763      +61     


@penelopeysm penelopeysm force-pushed the py/actually-use-init branch from 3a16f9c to 4c96020 Compare July 26, 2025 18:47
@penelopeysm penelopeysm force-pushed the py/init-prior-uniform branch from 9c07727 to ef038c6 Compare August 5, 2025 11:52
@penelopeysm penelopeysm force-pushed the py/actually-use-init branch from 4c96020 to 7b4c5fa Compare August 5, 2025 11:52
@penelopeysm penelopeysm force-pushed the py/init-prior-uniform branch from ef038c6 to 5961ca9 Compare August 8, 2025 10:15
@penelopeysm penelopeysm mentioned this pull request Aug 8, 2025
8 tasks
@penelopeysm penelopeysm force-pushed the py/init-prior-uniform branch 2 times, most recently from 0656487 to fd78d42 Compare August 8, 2025 10:21
@penelopeysm penelopeysm force-pushed the py/actually-use-init branch from 7b4c5fa to 23cafe0 Compare August 8, 2025 10:22
@penelopeysm penelopeysm force-pushed the py/init-prior-uniform branch from fd78d42 to bfcdbb9 Compare August 10, 2025 13:33
@penelopeysm penelopeysm force-pushed the py/actually-use-init branch from 23cafe0 to 8587eb7 Compare August 10, 2025 13:34
@penelopeysm penelopeysm force-pushed the py/init-prior-uniform branch from bfcdbb9 to ab3e8da Compare August 10, 2025 13:44
@penelopeysm penelopeysm force-pushed the py/actually-use-init branch from 8587eb7 to a6a42bd Compare August 10, 2025 13:45
Comment on lines +31 to +37
function tilde_assume(rng::Random.AbstractRNG, ::InitContext, sampler, right, vn, vi)
    @warn(
        "Encountered SamplingContext->InitContext. This method will be removed in the next PR.",
    )
    # just pretend the `InitContext` isn't there for now.
    return assume(rng, sampler, right, vn, vi)
end
@penelopeysm (Member Author) commented:

This was introduced to stop some test from failing (I think it was JETExt, but honestly, this PR was a month ago, so I already forgot exactly which one). It will indeed be removed together with SamplingContext.

Comment on lines -1228 to +1199
"""
predict([rng::Random.AbstractRNG,] model::Model, chain::AbstractVector{<:AbstractVarInfo})
Generate samples from the posterior predictive distribution by evaluating `model` at each set
of parameter values provided in `chain`. The number of posterior predictive samples matches
the length of `chain`. The returned `AbstractVarInfo`s will contain both the posterior parameter values
and the predicted values.
"""
function predict(
rng::Random.AbstractRNG, model::Model, chain::AbstractArray{<:AbstractVarInfo}
)
varinfo = DynamicPPL.VarInfo(model)
return map(chain) do params_varinfo
vi = deepcopy(varinfo)
DynamicPPL.setval_and_resample!(vi, values_as(params_varinfo, NamedTuple))
model(rng, vi)
return vi
end
end
# Implemented & documented in DynamicPPLMCMCChainsExt
function predict end
@penelopeysm (Member Author) commented:


This was discussed at one of the meetings, and we decided we didn't care enough about the predict method on vectors of varinfos. It's currently bugged: varinfo is always unlinked, but params_varinfo might be linked, and if it is, setting a linked value into an unlinked varinfo gives wrong results. See #983.
